#!/usr/bin/python2.6
# Copyright 2011 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A utility class to process ETW logs and distill them to discrete metrics.
"""

import etw
import etw.descriptors.pagefault as pagefault
import etw.descriptors.process as process
import logging
import os.path
import re
import trace_event

# This import is required to have some setup performed behind the scenes, but
# we don't ever directly refer to it. Ignore pylint complaining about an unused
# import.
# pylint: disable=W0611
import etw.descriptors.pagefault_xp as pagefault_xp

# TODO(siggi): Make these configurable?
# Matches the image file name of the processes whose faults and launches are
# counted (case-insensitive, whole-name match).
_CHROME_RE = re.compile(r'^chrome\.exe$', re.I)
# Basenames of the modules that get their own fault-counting bucket. Faults in
# any other module land in the _OTHER bucket. Module basenames are lower-cased
# before being looked up here, so entries must be lower-case.
_MODULES_TO_TRACK = set(['chrome.exe', 'chrome.dll'])


# These are used as bucket names in hard/soft fault accounting dicts.
# The names are chosen so that they sort before common module names, making
# the output more readable.
_ALL = '*ALL*'
_OTHER = '*OTHER*'


# Maps each soft-fault event type to the bucket name its faults are counted
# under.
_SOFTFAULT_TYPES = {
  pagefault.Event.AccessViolation: 'AccessViolation',
  pagefault.Event.CopyOnWrite: 'CopyOnWrite',
  pagefault.Event.DemandZeroFault: 'DemandZero',
  pagefault.Event.GuardPageFault: 'GuardPage',
  pagefault.Event.TransitionFault: 'Transition'
}


# Set up a file-local logger.
_LOGGER = logging.getLogger(__name__)


# The Chrome event corresponding to the start of the main message loop.
_MESSAGE_LOOP_BEGIN = 'BrowserMain:MESSAGE_LOOP'


def _MakeEmptySoftFaultDict():
  """Returns a new dict mapping every soft-fault bucket name to zero.

  The bucket names are the values of _SOFTFAULT_TYPES. A fresh dict is built
  on every call, so each caller may mutate its copy independently.
  """
  return dict.fromkeys(_SOFTFAULT_TYPES.values(), 0)


class LogEventCounter(etw.EventConsumer):
  """A utility class to parse salient metrics from ETW logs.

  While consuming events it records:
    - the time stamp of every trace BEGIN event named _MESSAGE_LOOP_BEGIN,
    - the time stamp of every process start whose image name matches
      _CHROME_RE,
    - hard- and soft-fault counts for processes whose image name matches
      _CHROME_RE, bucketed by the module containing the faulting address
      (one of _MODULES_TO_TRACK, or the catch-all _OTHER). Soft faults are
      further split by fault type.

  Call FinalizeCounts() once all events have been consumed to add the _ALL
  aggregate buckets.
  """

  def __init__(self, file_db, module_db, process_db):
    """Initialize a log event counter.

    Args:
        file_db: an etw_db.FileNameDatabase instance. Stored, but not
            otherwise referenced by this class.
        module_db: an etw_db.ModuleDatabase instance, used to find the
            module loaded at a faulting address.
        process_db: an etw_db.ProcessThreadDatabase instance, used to map a
            fault event back to its process.
    """
    # etw.EventConsumer is an old-style class, so super() doesn't work.
    etw.EventConsumer.__init__(self)

    self._file_db = file_db
    self._module_db = module_db
    self._process_db = process_db
    # Time stamps of message-loop BEGIN events and of matching process
    # starts, in the order the events were consumed.
    self._message_loop_begin = []
    self._process_launch = []

    # Initialize the fault counting structure. We count faults per module
    # being tracked in _MODULES_TO_TRACK. Soft-faults are further classified
    # based on their type. We have to initialize all possible values to zero
    # so that an output is always generated by the benchmark, even if no
    # events were observed.
    self._hardfaults = {}
    self._softfaults = {}
    for module_name in _MODULES_TO_TRACK.union([_OTHER]):
      self._hardfaults[module_name] = 0
      self._softfaults[module_name] = _MakeEmptySoftFaultDict()

  @etw.EventHandler(trace_event.Event.EVENT_BEGIN)
  def _OnBegin(self, event):
    """Records the time stamp of each message-loop BEGIN event."""
    if event.name == _MESSAGE_LOOP_BEGIN:
      self._message_loop_begin.append(event.time_stamp)

  @etw.EventHandler(trace_event.Event.EVENT_END)
  def _OnEnd(self, event):
    """Trace END events are handled but not used for any metric."""
    pass

  @etw.EventHandler(trace_event.Event.EVENT_INSTANT)
  def _OnInstant(self, event):
    """Trace INSTANT events are handled but not used for any metric."""
    pass

  @etw.EventHandler(process.Event.Start)
  def _OnProcessStart(self, event):
    """Records the time stamp of each process start matching _CHROME_RE."""
    if _CHROME_RE.search(event.ImageFileName):
      self._process_launch.append(event.time_stamp)

  def FinalizeCounts(self):
    """Adds the _ALL aggregate buckets to the fault counts.

    This is meant to be called when all events have been processed. It adds:
      - self._hardfaults[_ALL]: the hard-fault total across all modules;
      - self._softfaults[_ALL]: per soft-fault type, the total across all
        modules;
      - counts[_ALL] in every soft-fault dict (including the one above): the
        total over all soft-fault types.

    Any _ALL buckets left over from a previous call are ignored when the
    totals are recomputed. Calling this more than once therefore yields the
    same totals rather than double-counting or failing.
    """
    # Sum up the hard-faults across modules.
    self._hardfaults[_ALL] = sum(
        count for (module_name, count) in self._hardfaults.iteritems()
        if module_name != _ALL)

    # Create a catch-all module that sums each soft-fault type across all
    # modules.
    all_modules = _MakeEmptySoftFaultDict()
    for (module_name, counts) in self._softfaults.iteritems():
      if module_name == _ALL:
        continue
      for (fault_type, count) in counts.iteritems():
        # Skip the per-module total that a previous call may have added.
        if fault_type != _ALL:
          all_modules[fault_type] += count
    self._softfaults[_ALL] = all_modules

    # For each module create a bucket that holds the sum of all soft-fault
    # counts.
    for counts in self._softfaults.itervalues():
      counts[_ALL] = sum(
          count for (fault_type, count) in counts.iteritems()
          if fault_type != _ALL)

  def _GetModuleName(self, process_desc, fault):
    """Resolves the bucket name of the module a fault event belongs to.

    Args:
      process_desc: the process in which the fault occurred. Its process_id
          is used for the module lookup.
      fault: the fault event. Its VirtualAddress is the address looked up.

    Returns:
      The lower-cased basename of the module loaded at the faulting address
      if it is in _MODULES_TO_TRACK. Otherwise returns the catch-all module
      name _OTHER, including when no module is found at that address.
    """
    module = self._module_db.GetProcessModuleAt(
        process_desc.process_id, fault.VirtualAddress)
    if module:
      basename = os.path.basename(module.file_name).lower()
      if basename in _MODULES_TO_TRACK:
        return basename

    return _OTHER

  @etw.EventHandler(pagefault.Event.HardFault)
  def _OnHardFault(self, event):
    """Counts a hard fault in a matching process against its module."""
    # Resolve the thread id in the event back to the faulting process.
    # NOTE(review): hard faults are attributed via TThreadId, whereas soft
    # faults use event.process_id.
    process_desc = self._process_db.GetThreadProcess(event.TThreadId)
    if process_desc and _CHROME_RE.search(process_desc.image_file_name):
      module_name = self._GetModuleName(process_desc, event)
      self._hardfaults[module_name] += 1

  def _OnSoftFault(self, event, fault_type):
    """Counts a soft fault of the given type in a matching process.

    Args:
      event: the soft-fault event.
      fault_type: the bucket name for the fault, a value of _SOFTFAULT_TYPES.
    """
    # Resolve the faulting process.
    process_desc = self._process_db.GetProcess(event.process_id)
    if process_desc and _CHROME_RE.search(process_desc.image_file_name):
      module_name = self._GetModuleName(process_desc, event)
      self._softfaults[module_name][fault_type] += 1

  # Each soft-fault event type needs its own registered handler. All of them
  # forward to _OnSoftFault with that type's bucket name from
  # _SOFTFAULT_TYPES.

  @etw.EventHandler(pagefault.Event.AccessViolation)
  def _OnAccessViolation(self, event):
    fault_type = _SOFTFAULT_TYPES[pagefault.Event.AccessViolation]
    return self._OnSoftFault(event, fault_type)

  @etw.EventHandler(pagefault.Event.CopyOnWrite)
  def _OnCopyOnWrite(self, event):
    fault_type = _SOFTFAULT_TYPES[pagefault.Event.CopyOnWrite]
    return self._OnSoftFault(event, fault_type)

  @etw.EventHandler(pagefault.Event.DemandZeroFault)
  def _OnDemandZeroFault(self, event):
    fault_type = _SOFTFAULT_TYPES[pagefault.Event.DemandZeroFault]
    return self._OnSoftFault(event, fault_type)

  @etw.EventHandler(pagefault.Event.GuardPageFault)
  def _OnGuardPageFault(self, event):
    fault_type = _SOFTFAULT_TYPES[pagefault.Event.GuardPageFault]
    return self._OnSoftFault(event, fault_type)

  @etw.EventHandler(pagefault.Event.TransitionFault)
  def _OnTransitionFault(self, event):
    fault_type = _SOFTFAULT_TYPES[pagefault.Event.TransitionFault]
    return self._OnSoftFault(event, fault_type)
